{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wR53lePHuiP-"
      },
      "source": [
        "# Finetune PaliGemma\n",
        "\n",
        "> *These models and code are not official Google products and were trained and released for research purposes.*\n",
        "\n",
        "\n",
        "**This notebook shows how to finetune PaliGemma on a vision-language task.**\n",
        "The training data consists of 90 pairs of images and long captions describing them.\n",
        "To make it runnable on a T4 colab runtime with 16GB HBM and 12GB RAM, we opt to only finetune the attention layers of the language model and freeze the other parameters.\n",
        "\n",
        " **This setup is illustrative**. In a real usecase, the amount of data, trainable parameters, training steps and hyper-parameters and obtained results could be significantly different.\n",
        "\n",
        "This notebook uses the model reference implementation from [big_vision](https://github.com/google-research/big_vision).\n",
        "and shows how to:\n",
        "\n",
        " * Install deps, download model checkpoint and training data.\n",
        " * Load the model onto GPU devices.\n",
        " * Prepare the input to the model for training and inference.\n",
        " * Finetune the model and inspect output in validation split."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6U0QUFveqSP2"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "DfxKb3F839Ks",
        "outputId": "d02e98d5-8334-463f-f529-6292dd73b04b",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m77.9/77.9 kB\u001b[0m \u001b[31m3.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.2/43.2 kB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h  Building wheel for ml_collections (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
          ]
        }
      ],
      "source": [
        "# @title Fetch big_vision code and install dependencies.\n",
        "import os\n",
        "import sys\n",
        "\n",
        "# TPUs with\n",
        "if \"COLAB_TPU_ADDR\" in os.environ:\n",
        "  raise \"It seems you are using Colab with remote TPUs which is not supported.\"\n",
        "\n",
        "# Fetch big_vision repository if python doesn't know about it and install\n",
        "# dependencies needed for this notebook.\n",
        "if not os.path.exists(\"big_vision_repo\"):\n",
        "  !git clone --quiet --branch=main --depth=1 \\\n",
        "     https://github.com/google-research/big_vision big_vision_repo\n",
        "\n",
        "# Append big_vision code to python import path\n",
        "if \"big_vision_repo\" not in sys.path:\n",
        "  sys.path.append(\"big_vision_repo\")\n",
        "\n",
        "# Install missing dependencies. Assume jax~=0.4.25 with GPU available.\n",
        "!pip3 install -q \"overrides\" \"ml_collections\" \"einops~=0.7\" \"sentencepiece\"\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "azmRZvgGyhAb"
      },
      "source": [
        "### Configure your API key to access Kaggle\n",
        "\n",
        "To use PaliGemma, you must provide your Kaggle username and a Kaggle API key.\n",
        "\n",
        "1. To generate a Kaggle API key, go to the **Account** tab of your Kaggle user profile and select **Create New Token**. This will trigger the download of a `kaggle.json` file containing your API credentials.\n",
        "1. In Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`.\n",
        "\n",
        "To be able to download, you will also need to acknowledge the Terms and Conditions of the PaliGemma on:\n",
        "\n",
        "* https://www.kaggle.com/models/google/paligemma/\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "zGLIp1Cx3_CX"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import userdata\n",
        "\n",
        "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
        "# vars as appropriate or make your credentials available in ~/.kaggle/kaggle.json\n",
        "\n",
        "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
        "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')"
      ]
    },
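    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Outside Colab, `userdata.get` is not available. The cell below is a minimal sketch, not part of the original notebook, of one way to fill the same environment variables from a `kaggle.json` file. The helper name `load_kaggle_credentials` is hypothetical, and the code assumes the file holds a JSON object with `username` and `key` fields, which is the layout of the token file that Kaggle generates.\n",
        "\n",
        "Passing a plain `dict` as `environ` lets you check the result without touching the real process environment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import json\n",
        "import os\n",
        "\n",
        "\n",
        "def load_kaggle_credentials(path=\"~/.kaggle/kaggle.json\", environ=None):\n",
        "  \"\"\"Hypothetical helper: copy Kaggle credentials from a JSON file to env vars.\"\"\"\n",
        "  environ = os.environ if environ is None else environ\n",
        "  with open(os.path.expanduser(path)) as f:\n",
        "    creds = json.load(f)\n",
        "  # Assumed file layout: {\"username\": \"...\", \"key\": \"...\"}.\n",
        "  environ[\"KAGGLE_USERNAME\"] = creds[\"username\"]\n",
        "  environ[\"KAGGLE_KEY\"] = creds[\"key\"]\n",
        "  return environ\n",
        "\n",
        "\n",
        "# Usage outside Colab (commented out so that this cell does nothing in Colab):\n",
        "# load_kaggle_credentials()"
      ]
    },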
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gQNOTfF24AV4",
        "outputId": "54f8aeed-bdbd-4ab3-941b-373392591505"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Downloading the checkpoint from Kaggle, this could take a few minutes....\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Downloading from https://www.kaggle.com/api/v1/models/google/paligemma/jax/paligemma-3b-pt-224/1/download/paligemma-3b-pt-224.f16.npz...\n",
            "100%|██████████| 5.45G/5.45G [01:00<00:00, 95.9MB/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Model path: /root/.cache/kagglehub/models/google/paligemma/jax/paligemma-3b-pt-224/1/./paligemma-3b-pt-224.f16.npz\n",
            "Downloading the model tokenizer...\n",
            "Copying gs://big_vision/paligemma_tokenizer.model...\n",
            "- [1 files][  4.1 MiB/  4.1 MiB]                                                \n",
            "Operation completed over 1 objects/4.1 MiB.                                      \n",
            "Tokenizer path: ./paligemma_tokenizer.model\n",
            "Downloading the dataset...\n",
            "Data path: ./longcap100\n"
          ]
        }
      ],
      "source": [
        "# @title Download checkpoint, tokenizer and dataset to local filesystem.\n",
        "#\n",
        "import os\n",
        "import kagglehub\n",
        "\n",
        "MODEL_PATH = \"./paligemma-3b-pt-224.f16.npz\"\n",
        "if not os.path.exists(MODEL_PATH):\n",
        "  print(\"Downloading the checkpoint from Kaggle, this could take a few minutes....\")\n",
        "  # Note: kaggle archive contains the same checkpoint in multiple formats.\n",
        "  # Download only the float16 model.\n",
        "  MODEL_PATH = kagglehub.model_download('google/paligemma/jax/paligemma-3b-pt-224', MODEL_PATH)\n",
        "  print(f\"Model path: {MODEL_PATH}\")\n",
        "\n",
        "TOKENIZER_PATH = \"./paligemma_tokenizer.model\"\n",
        "if not os.path.exists(TOKENIZER_PATH):\n",
        "  print(\"Downloading the model tokenizer...\")\n",
        "  !gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}\n",
        "  print(f\"Tokenizer path: {TOKENIZER_PATH}\")\n",
        "\n",
        "DATA_DIR=\"./longcap100\"\n",
        "if not os.path.exists(DATA_DIR):\n",
        "  print(\"Downloading the dataset...\")\n",
        "  !gsutil -m -q cp -n -r gs://longcap100/ .\n",
        "  print(f\"Data path: {DATA_DIR}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zDoq0O77GF30"
      },
      "source": [
        "## Notebook"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "dTfe2k8J4Bw0",
        "outputId": "b9864437-9e35-493a-bf52-019c18d5dfd9"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "JAX version:  0.4.26\n",
            "JAX platform: gpu\n",
            "JAX devices:  1\n"
          ]
        }
      ],
      "source": [
        "import base64\n",
        "import functools\n",
        "import html\n",
        "import io\n",
        "import os\n",
        "import warnings\n",
        "\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import numpy as np\n",
        "import ml_collections\n",
        "\n",
        "import tensorflow as tf\n",
        "import sentencepiece\n",
        "\n",
        "from IPython.core.display import display, HTML\n",
        "from PIL import Image\n",
        "\n",
        "# Import model definition from big_vision\n",
        "from big_vision.models.proj.paligemma import paligemma\n",
        "from big_vision.trainers.proj.paligemma import predict_fns\n",
        "\n",
        "# Import big vision utilities\n",
        "import big_vision.datasets.jsonl\n",
        "import big_vision.utils\n",
        "import big_vision.sharding\n",
        "\n",
        "# Don't let TF use the GPU or TPUs\n",
        "tf.config.set_visible_devices([], \"GPU\")\n",
        "tf.config.set_visible_devices([], \"TPU\")\n",
        "\n",
        "backend = jax.lib.xla_bridge.get_backend()\n",
        "print(f\"JAX version:  {jax.__version__}\")\n",
        "print(f\"JAX platform: {backend.platform}\")\n",
        "print(f\"JAX devices:  {jax.device_count()}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "1aghcULcEdtv"
      },
      "outputs": [],
      "source": [
        "# @title Construct model and load params into RAM.\n",
        "\n",
        "# Define model\n",
        "model_config = ml_collections.FrozenConfigDict({\n",
        "    \"llm\": {\"vocab_size\": 257_152},\n",
        "    \"img\": {\"variant\": \"So400m/14\", \"pool_type\": \"none\", \"scan\": True, \"dtype_mm\": \"float16\"}\n",
        "})\n",
        "model = paligemma.Model(**model_config)\n",
        "tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)\n",
        "\n",
        "# Load params - this can take up to 1 minute in T4 colabs.\n",
        "params = paligemma.load(None, MODEL_PATH, model_config)\n",
        "\n",
        "# Define `decode` function to sample outputs from the model.\n",
        "decode_fn = predict_fns.get_all(model)['decode']\n",
        "decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "RWOdf_fw2SAO",
        "outputId": "6d48433f-7410-480d-b889-e2b679caa8a6"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            " == Model params == \n",
            "img/Transformer/encoder_norm/bias                                                (1152,)                float16\n",
            "img/Transformer/encoder_norm/scale                                               (1152,)                float16\n",
            "img/Transformer/encoderblock/LayerNorm_0/bias                                    (27, 1152)             float16\n",
            "img/Transformer/encoderblock/LayerNorm_0/scale                                   (27, 1152)             float16\n",
            "img/Transformer/encoderblock/LayerNorm_1/bias                                    (27, 1152)             float16\n",
            "img/Transformer/encoderblock/LayerNorm_1/scale                                   (27, 1152)             float16\n",
            "img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias                             (27, 4304)             float16\n",
            "img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel                           (27, 1152, 4304)       float16\n",
            "img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias                             (27, 1152)             float16\n",
            "img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel                           (27, 4304, 1152)       float16\n",
            "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias             (27, 16, 72)           float16\n",
            "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel           (27, 1152, 16, 72)     float16\n",
            "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias             (27, 1152)             float16\n",
            "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel           (27, 16, 72, 1152)     float16\n",
            "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias           (27, 16, 72)           float16\n",
            "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel         (27, 1152, 16, 72)     float16\n",
            "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias           (27, 16, 72)           float16\n",
            "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel         (27, 1152, 16, 72)     float16\n",
            "img/embedding/bias                                                               (1152,)                float16\n",
            "img/embedding/kernel                                                             (14, 14, 3, 1152)      float16\n",
            "img/head/bias                                                                    (2048,)                float16\n",
            "img/head/kernel                                                                  (1152, 2048)           float16\n",
            "img/pos_embedding                                                                (1, 256, 1152)         float16\n",
            "llm/embedder/input_embedding                                                     (257152, 2048)         float16\n",
            "llm/final_norm/scale                                                             (2048,)                float16\n",
            "llm/layers/attn/attn_vec_einsum/w                                                (18, 8, 256, 2048)     float32\n",
            "llm/layers/attn/kv_einsum/w                                                      (18, 2, 1, 2048, 256)  float32\n",
            "llm/layers/attn/q_einsum/w                                                       (18, 8, 2048, 256)     float32\n",
            "llm/layers/mlp/gating_einsum                                                     (18, 2, 2048, 16384)   float16\n",
            "llm/layers/mlp/linear                                                            (18, 16384, 2048)      float16\n",
            "llm/layers/pre_attention_norm/scale                                              (18, 2048)             float16\n",
            "llm/layers/pre_ffw_norm/scale                                                    (18, 2048)             float16\n"
          ]
        }
      ],
      "source": [
        "# @title Move params to GPU/TPU memory.\n",
        "#\n",
        "# To keep HBM usage low and fit in a T4 GPU (16GB HBM) we opt to only finetune\n",
        "# a part of the parameters. Additionally we keep the frozen params in float16\n",
        "# and cast trainable to float32.\n",
        "\n",
        "# Create a pytree mask of the trainable params.\n",
        "def is_trainable_param(name, param):  # pylint: disable=unused-argument\n",
        "  if name.startswith(\"llm/layers/attn/\"):  return True\n",
        "  if name.startswith(\"llm/\"):              return False\n",
        "  if name.startswith(\"img/\"):              return False\n",
        "  raise ValueError(f\"Unexpected param name {name}\")\n",
        "trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)\n",
        "\n",
        "#\n",
        "# If more than one device is available (e.g. multiple GPUs) the parameters can\n",
        "# be sharded across them to reduce HBM usage per device.\n",
        "mesh = jax.sharding.Mesh(jax.devices(), (\"data\"))\n",
        "\n",
        "data_sharding = jax.sharding.NamedSharding(\n",
        "    mesh, jax.sharding.PartitionSpec(\"data\"))\n",
        "\n",
        "params_sharding = big_vision.sharding.infer_sharding(\n",
        "    params, strategy=[('.*', 'fsdp(axis=\"data\")')], mesh=mesh)\n",
        "\n",
        "# Yes: Some donated buffers are not usable.\n",
        "warnings.filterwarnings(\n",
        "    \"ignore\", message=\"Some donated buffers were not usable\")\n",
        "\n",
        "@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))\n",
        "def maybe_cast_to_f32(params, trainable):\n",
        "  return jax.tree.map(lambda p, m: p.astype(jnp.float32) if m else p,\n",
        "                      params, trainable)\n",
        "\n",
        "# Loading all params in simultaneous - albeit much faster and more succinct -\n",
        "# requires more RAM than the T4 colab runtimes have by default (12GB RAM).\n",
        "# Instead we do it param by param.\n",
        "params, treedef = jax.tree.flatten(params)\n",
        "sharding_leaves = jax.tree.leaves(params_sharding)\n",
        "trainable_leaves = jax.tree.leaves(trainable_mask)\n",
        "for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):\n",
        "  params[idx] = big_vision.utils.reshard(params[idx], sharding)\n",
        "  params[idx] = maybe_cast_to_f32(params[idx], trainable)\n",
        "  params[idx].block_until_ready()\n",
        "params = jax.tree.unflatten(treedef, params)\n",
        "\n",
        "# Print params to show what the model is made of.\n",
        "def parameter_overview(params):\n",
        "  for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:\n",
        "    print(f\"{path:80s} {str(arr.shape):22s} {arr.dtype}\")\n",
        "\n",
        "print(\" == Model params == \")\n",
        "parameter_overview(params)"
      ]
    },
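    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the masking and casting logic above concrete, here is a small self-contained sketch on a toy parameter dictionary. It uses plain NumPy and a flat name-to-array `dict` rather than `big_vision.utils.tree_map_with_names` and JAX pytrees; the names `toy_params`, `toy_is_trainable` and `toy_cast` are made up for illustration.\n",
        "\n",
        "Only the entry under `llm/layers/attn/` should come out as `float32`; the frozen entries stay in `float16`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def toy_is_trainable(name):\n",
        "  # Same rule as above: only the LLM attention weights are trainable.\n",
        "  return name.startswith(\"llm/layers/attn/\")\n",
        "\n",
        "\n",
        "def toy_cast(params):\n",
        "  # Trainable arrays go to float32, frozen arrays keep their float16 dtype.\n",
        "  return {name: p.astype(np.float32) if toy_is_trainable(name) else p\n",
        "          for name, p in params.items()}\n",
        "\n",
        "\n",
        "# Tiny stand-ins for real parameters (shapes are arbitrary).\n",
        "toy_params = {\n",
        "    \"img/head/kernel\": np.zeros((4, 8), np.float16),\n",
        "    \"llm/layers/attn/q_einsum/w\": np.zeros((2, 8, 4), np.float16),\n",
        "    \"llm/layers/mlp/linear\": np.zeros((2, 16, 8), np.float16),\n",
        "}\n",
        "for name, p in toy_cast(toy_params).items():\n",
        "  print(f\"{name:30s} {str(p.shape):12s} {p.dtype}\")"
      ]
    },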
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "8SRW0NuU4UcW"
      },
      "outputs": [],
      "source": [
        "# @title Define preprocess functions to create inputs to the model.\n",
        "\n",
        "def preprocess_image(image, size=224):\n",
        "  # Model has been trained to handle images of different aspects ratios\n",
        "  # resized to 224x224 in the range [-1, 1]. Bilinear and antialias resize\n",
        "  # options are helpful to improve quality in some tasks.\n",
        "  image = np.asarray(image)\n",
        "  if image.ndim == 2:  # Convert image without last channel into greyscale.\n",
        "    image = np.stack((image,)*3, axis=-1)\n",
        "  image = image[..., :3]  # Remove alpha layer.\n",
        "  assert image.shape[-1] == 3\n",
        "\n",
        "  image = tf.constant(image)\n",
        "  image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)\n",
        "  return image.numpy() / 127.5 - 1.0  # [0, 255]->[-1,1]\n",
        "\n",
        "def preprocess_tokens(prefix, suffix=None, seqlen=None):\n",
        "  # Model has been trained to handle tokenized text composed of a prefix with\n",
        "  # full attention and a suffix with causal attention.\n",
        "  separator = \"\\n\"\n",
        "  tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)\n",
        "  mask_ar = [0] * len(tokens)    # 0 to use full attention for prefix.\n",
        "  mask_loss = [0] * len(tokens)  # 0 to not use prefix tokens in the loss.\n",
        "\n",
        "  if suffix:\n",
        "    suffix = tokenizer.encode(suffix, add_eos=True)\n",
        "    tokens += suffix\n",
        "    mask_ar += [1] * len(suffix)    # 1 to use causal attention for suffix.\n",
        "    mask_loss += [1] * len(suffix)  # 1 to use suffix tokens in the loss.\n",
        "\n",
        "  mask_input = [1] * len(tokens)    # 1 if its a token, 0 if padding.\n",
        "  if seqlen:\n",
        "    padding = [0] * max(0, seqlen - len(tokens))\n",
        "    tokens = tokens[:seqlen] + padding\n",
        "    mask_ar = mask_ar[:seqlen] + padding\n",
        "    mask_loss = mask_loss[:seqlen] + padding\n",
        "    mask_input = mask_input[:seqlen] + padding\n",
        "\n",
        "  return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))\n",
        "\n",
        "def postprocess_tokens(tokens):\n",
        "  tokens = tokens.tolist()  # np.array to list[int]\n",
        "  try:  # Remove tokens at and after EOS if any.\n",
        "    eos_pos = tokens.index(tokenizer.eos_id())\n",
        "    tokens = tokens[:eos_pos]\n",
        "  except ValueError:\n",
        "    pass\n",
        "  return tokenizer.decode(tokens)\n"
      ]
    },
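    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The mask layout produced by `preprocess_tokens` can be easier to see without a tokenizer. The sketch below repeats the same list arithmetic on made-up token ids; the ids and the helper name `toy_masks` are illustrative assumptions, not real tokenizer output.\n",
        "\n",
        "Note that, as in `preprocess_tokens`, a sequence longer than `seqlen` is simply cut off, so its final EOS token is dropped."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def toy_masks(prefix_ids, suffix_ids, seqlen):\n",
        "  tokens = list(prefix_ids) + list(suffix_ids)\n",
        "  mask_ar = [0] * len(prefix_ids) + [1] * len(suffix_ids)    # full vs causal attention\n",
        "  mask_loss = [0] * len(prefix_ids) + [1] * len(suffix_ids)  # loss on suffix only\n",
        "  mask_input = [1] * len(tokens)                             # real token vs padding\n",
        "  padding = [0] * max(0, seqlen - len(tokens))\n",
        "  rows = (tokens, mask_ar, mask_loss, mask_input)\n",
        "  # Truncate to seqlen, then pad with zeros up to seqlen.\n",
        "  return tuple(np.array(r[:seqlen] + padding) for r in rows)\n",
        "\n",
        "\n",
        "# Made-up ids: three prefix tokens (BOS, text, separator), two suffix tokens (text, EOS).\n",
        "toy_tokens, toy_ar, toy_loss, toy_input = toy_masks([2, 10, 11], [20, 1], seqlen=8)\n",
        "print(\"tokens    \", toy_tokens)\n",
        "print(\"mask_ar   \", toy_ar)\n",
        "print(\"mask_loss \", toy_loss)\n",
        "print(\"mask_input\", toy_input)"
      ]
    },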
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "whzWOojGOtzi"
      },
      "outputs": [],
      "source": [
        "# @title Function to iterate over train and validation examples.\n",
        "SEQLEN = 128\n",
        "\n",
        "# TODO: Consider data iterators skipping big_vision and tf.data?\n",
        "train_dataset = big_vision.datasets.jsonl.DataSource(\n",
        "    os.path.join(DATA_DIR, \"data_train90.jsonl\"),\n",
        "    fopen_keys={\"image\": DATA_DIR})\n",
        "\n",
        "val_dataset = big_vision.datasets.jsonl.DataSource(\n",
        "    os.path.join(DATA_DIR, \"data_val10.jsonl\"),\n",
        "    fopen_keys={\"image\": DATA_DIR})\n",
        "\n",
        "\n",
        "def train_data_iterator():\n",
        "  \"\"\"Never ending iterator over training examples.\"\"\"\n",
        "  # Shuffle examples and repeat so one can train for many epochs.\n",
        "  dataset = train_dataset.get_tfdata().shuffle(1_000).repeat()\n",
        "  for example in dataset.as_numpy_iterator():\n",
        "    image = Image.open(io.BytesIO(example[\"image\"]))\n",
        "    image = preprocess_image(image)\n",
        "\n",
        "    prefix = \"caption en\"  # Could also be a different prefix per example.\n",
        "    suffix = example[\"suffix\"].decode().lower()\n",
        "    tokens, mask_ar, mask_loss, _ = preprocess_tokens(prefix, suffix, SEQLEN)\n",
        "\n",
        "    yield {\n",
        "        \"image\": np.asarray(image),\n",
        "        \"text\": np.asarray(tokens),\n",
        "        \"mask_ar\": np.asarray(mask_ar),\n",
        "        \"mask_loss\": np.asarray(mask_loss),\n",
        "    }\n",
        "\n",
        "\n",
        "def validation_data_iterator():\n",
        "  \"\"\"Single iterator over validation examples.\"\"\"\n",
        "  for example in val_dataset.get_tfdata(ordered=True).as_numpy_iterator():\n",
        "    image = Image.open(io.BytesIO(example[\"image\"]))\n",
        "    image = preprocess_image(image)\n",
        "\n",
        "    prefix = \"caption en\"  # Could also be a different prefix per example.\n",
        "    tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)\n",
        "\n",
        "    yield {\n",
        "        \"image\": np.asarray(image),\n",
        "        \"text\": np.asarray(tokens),\n",
        "        \"mask_ar\": np.asarray(mask_ar),\n",
        "        \"mask_input\": np.asarray(mask_input),\n",
        "    }\n"
      ]
    },
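    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The iterators above yield one example at a time, while the model consumes batches. As a rough sketch of one way to group them, examples can be stacked key by key. The helper name `toy_batches` and the toy example shapes are assumptions made for illustration; the batching code used later in this notebook may differ.\n",
        "\n",
        "With a finite iterator the last batch can be smaller than `batch_size`, which matters if the consumer expects a fixed batch shape."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import itertools\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def toy_batches(example_iterator, batch_size):\n",
        "  \"\"\"Stack consecutive example dicts into dicts of batched arrays.\"\"\"\n",
        "  iterator = iter(example_iterator)\n",
        "  while True:\n",
        "    examples = list(itertools.islice(iterator, batch_size))\n",
        "    if not examples:\n",
        "      return  # Iterator exhausted.\n",
        "    yield {k: np.stack([ex[k] for ex in examples]) for k in examples[0]}\n",
        "\n",
        "\n",
        "# Five fake examples with tiny made-up shapes.\n",
        "toy_examples = ({\"image\": np.zeros((4, 4, 3)), \"text\": np.full((8,), i)}\n",
        "                for i in range(5))\n",
        "for toy_batch in toy_batches(toy_examples, batch_size=2):\n",
        "  print({k: v.shape for k, v in toy_batch.items()})"
      ]
    },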
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 516
        },
        "id": "BzJfb5t0nsLq",
        "outputId": "1f6640f7-09b4-41a3-c713-62966b0df7e7"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Training examples\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "\n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a tall crane stands in the heart of the city, casting long shadows across the streets below. the sky is clear and blue, with fluffy white clouds drifting lazily. a tall white building dominates the skyline, its windows reflecting the afternoon sun. a tall black building casts a long shadow on the ground. </p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a neon sign on a brick wall reads &quot;this is the sign you ve been looking for.&quot; the sign is lit up and the letters are white. there are several pillows on the wall, including a pillow with a skull and crossbones. the wall is made of bricks and the sign is on the wall. the sign is neon and the letters are white.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a tool box filled with a variety of tools, including a wrench with a silver head, a screwdriver with a gray handle,a wrench with a gray head, a screwdriver with a gray handle, a metal socket with a silver head...</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a person stands on a sidewalk, their shoes and legs visible. the ground is made of concrete. the person wears black pants, with a white sole on their shoe and a white lace on their shoe. the shoes are black and white. the person&#x27;s legs are visible. the words &quot;passion led us here&quot; are written on the ground in red. the concrete has a shadow on it, and the sun shines on the ground</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a bowl of steaming instant noodle soup with a spoon resting in the center. the broth is clear and the vegetables, including carrots, peas, and green beans, are floating gently in the liquid. the spoon is long and silver, with a reflection of light on its handle. the overall image is simple and straightforward, with a focus on the deliciousness of the soup.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a kitchen with a stove and a window. the room has a green floor and a white wall. there is a white towel hanging on the oven door. the stove has a black oven door, a black knob on the stove, and a black and white oven. there is a a silver pot on the stove. the window has a yellow frame and there is a white plastic bag hanging on the oven door. the door is open.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a plate of colorful desserts and a cup of coffee. the plate features a variety of sweet treats, including a macaron with a bite taken out, a green macaron, a pink macaron, and a white macaron. the plate is adorned with a white flower on a tree branch and a green leaf on a plant. the coffee cup has a white handle. the table is covered in a white and gold tablecloth.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a car is parked in front of a building with a green umbrella. the building has a green and white sign, a green and white umbrella, and a white sign with black lettering. there is a tall palm tree and a tall tree. the car is parked next to a yellow car and a black car. the road is grey and the sky is white.</p>\n",
              "    </div>\n",
              "    "
            ]
          },
          "metadata": {}
        }
      ],
      "source": [
        "# @title Inspect training examples.\n",
        "def render_inline(image, resize=(128, 128)):\n",
        "  \"\"\"Convert image into inline html.\"\"\"\n",
        "  image = Image.fromarray(image)\n",
        "  image.resize(resize)\n",
        "  with io.BytesIO() as buffer:\n",
        "    image.save(buffer, format='jpeg')\n",
        "    image_b64 = str(base64.b64encode(buffer.getvalue()), \"utf-8\")\n",
        "    return f\"data:image/jpeg;base64,{image_b64}\"\n",
        "\n",
        "def render_example(image, caption):\n",
        "  image = ((image + 1)/2 * 255).astype(np.uint8)  # [-1,1] -> [0, 255]\n",
        "  return f\"\"\"\n",
        "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
        "        <img style=\"width:128px; height:128px;\" src=\"{render_inline(image, resize=(64,64))}\" />\n",
        "        <p style=\"width:256px; margin:10px; font-size:small;\">{html.escape(caption)}</p>\n",
        "    </div>\n",
        "    \"\"\"\n",
        "\n",
        "html_out = \"\"\n",
        "for idx, example in zip(range(8), train_data_iterator()):\n",
        "  caption = postprocess_tokens(example[\"text\"])  # detokenize model input.\n",
        "  caption = caption[len(\"caption en\\n\"):]        # strip prefix\n",
        "  html_out += render_example(example[\"image\"], caption)\n",
        "\n",
        "print(\"Training examples\")\n",
        "display(HTML(html_out))"
      ]
    },
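    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The next cell defines the training loss as a masked, per-example average of token log-likelihoods. As a sanity check of that arithmetic, here is a NumPy-only sketch on tiny made-up logits. `toy_masked_loss` is a hypothetical name, and the function mirrors the JAX `loss_fn` rather than replacing it.\n",
        "\n",
        "With uniform logits every token has probability 1/4, so the loss should equal `log(4)` regardless of how many tokens each example contributes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def toy_masked_loss(logits, targets, mask_loss):\n",
        "  # Numerically stable log-softmax over the vocabulary axis.\n",
        "  shifted = logits - logits.max(axis=-1, keepdims=True)\n",
        "  logp = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))\n",
        "  # Log-probability of each target token, shape [batch, seq_len].\n",
        "  token_logp = np.take_along_axis(logp, targets[..., None], axis=-1)[..., 0]\n",
        "  # Negative log-likelihood summed over unmasked tokens, averaged per example.\n",
        "  example_loss = -(token_logp * mask_loss).sum(axis=-1)\n",
        "  example_loss /= np.clip(mask_loss.sum(axis=-1), 1, None)\n",
        "  return example_loss.mean()\n",
        "\n",
        "\n",
        "# Uniform logits over a made-up vocabulary of 4 tokens.\n",
        "toy_logits = np.zeros((2, 3, 4))\n",
        "toy_targets = np.array([[0, 1, 2], [3, 0, 0]])\n",
        "toy_mask = np.array([[0, 1, 1], [1, 0, 0]])\n",
        "print(toy_masked_loss(toy_logits, toy_targets, toy_mask), np.log(4))"
      ]
    },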
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "dwUV_imW3WQJ"
      },
      "outputs": [],
      "source": [
        "# @title Define the training step and evaluation loop.\n",
        "#\n",
        "# The main update_fn using simple SGD.\n",
        "#\n",
        "@functools.partial(jax.jit, donate_argnums=(0,))\n",
        "def update_fn(params, batch, learning_rate):\n",
        "  imgs, txts, mask_ar = batch[\"image\"], batch[\"text\"], batch[\"mask_ar\"]\n",
        "\n",
        "  def loss_fn(params):\n",
        "    text_logits, _ = model.apply({\"params\": params}, imgs, txts[:, :-1], mask_ar[:, :-1], train=True)\n",
        "    logp = jax.nn.log_softmax(text_logits, axis=-1)\n",
        "\n",
        "    # The model takes as input txts[:, :-1] but the loss is defined as predicting\n",
        "    # next tokens txts[:, 1:]. Additionally, mask_loss[:, 1:] indicates which tokens\n",
        "    # are part of the loss (e.g. prefix and padded tokens are not included).\n",
        "    mask_loss = batch[\"mask_loss\"][:, 1:]\n",
        "    targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])\n",
        "\n",
        "    # Compute the loss per example. i.e. the mean of per token pplx.\n",
        "    # Since each example has a different number of tokens we normalize it.\n",
        "    token_pplx = jnp.sum(logp * targets, axis=-1)  # sum across vocab_size.\n",
        "    example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1)  # sum across seq_len.\n",
        "    example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1)  # weight by num of tokens.\n",
        "\n",
        "    # batch_loss: mean of per example loss.\n",
        "    return jnp.mean(example_loss)\n",
        "\n",
        "  loss, grads = jax.value_and_grad(loss_fn)(params)\n",
        "\n",
        "  # Apply gradients to trainable params using SGD.\n",
        "  def apply_grad(param, gradient, trainable):\n",
        "    if not trainable: return param\n",
        "    return param - learning_rate * gradient\n",
        "\n",
        "  params = jax.tree_util.tree_map(apply_grad, params, grads, trainable_mask)\n",
        "\n",
        "  return params, loss\n",
        "\n",
        "# Evaluation/inference loop.\n",
        "def make_predictions(data_iterator, *, num_examples=None,\n",
        "                     batch_size=4, seqlen=SEQLEN, sampler=\"greedy\"):\n",
        "  outputs = []\n",
        "  while True:\n",
        "    # Construct a list of examples in the batch.\n",
        "    examples = []\n",
        "    try:\n",
        "      for _ in range(batch_size):\n",
        "        examples.append(next(data_iterator))\n",
        "        examples[-1][\"_mask\"] = np.array(True)  # Indicates true example.\n",
        "    except StopIteration:\n",
        "      if len(examples) == 0:\n",
        "        return outputs\n",
        "\n",
        "    # Not enough examples to complete a batch. Pad by repeating last example.\n",
        "    while len(examples) % batch_size:\n",
        "      examples.append(dict(examples[-1]))\n",
        "      examples[-1][\"_mask\"] = np.array(False)  # Indicates padding example.\n",
        "\n",
        "    # Convert list of examples into a dict of np.arrays and load onto devices.\n",
        "    batch = jax.tree.map(lambda *x: np.stack(x), *examples)\n",
        "    batch = big_vision.utils.reshard(batch, data_sharding)\n",
        "\n",
        "    # Make model predictions\n",
        "    tokens = decode({\"params\": params}, batch=batch,\n",
        "                    max_decode_len=seqlen, sampler=sampler)\n",
        "\n",
        "    # Fetch model predictions to device and detokenize.\n",
        "    tokens, mask = jax.device_get((tokens, batch[\"_mask\"]))\n",
        "    tokens = tokens[mask]  # remove padding examples.\n",
        "    responses = [postprocess_tokens(t) for t in tokens]\n",
        "\n",
        "    # Append to html output.\n",
        "    for example, response in zip(examples, responses):\n",
        "      outputs.append((example[\"image\"], response))\n",
        "      if num_examples and len(outputs) >= num_examples:\n",
        "        return outputs"
      ]
    },
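    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The loss computed in `loss_fn` above can be written out as a rough sketch. The notation is an assumption introduced here for illustration and is not part of the notebook: $x_{i,t}$ is token $t$ of example $i$ (`txts`, counted from 0), $m_{i,t}$ is the matching entry of `mask_loss`, $I_i$ is the image, $L$ is the sequence length, $B$ is the batch size and $\\eta$ is `learning_rate`.\n",
        "\n",
        "$$\\ell_i = -\\frac{\\sum_{t=1}^{L-1} m_{i,t}\\,\\log p_\\theta\\left(x_{i,t} \\mid I_i, x_{i,\\lt t}\\right)}{\\max\\left(1, \\sum_{t=1}^{L-1} m_{i,t}\\right)}, \\qquad \\mathcal{L} = \\frac{1}{B}\\sum_{i=1}^{B} \\ell_i$$\n",
        "\n",
        "Parameters marked trainable in `trainable_mask` then take one plain SGD step, while frozen parameters are returned unchanged:\n",
        "\n",
        "$$\\theta \\leftarrow \\theta - \\eta\\,\\nabla_\\theta \\mathcal{L}$$\n",
        "\n",
        "Dividing by the number of unmasked tokens means a long caption and a short caption contribute equally to the batch loss."
      ]
    },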
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "067wj_6bZAG3",
        "outputId": "e1aa2df0-502e-4a70-c88d-db98739c01d5"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "step:  1/64   lr: 0.00500   loss: 2.7898\n",
            "Model predictions at step 1\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "\n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">the beauty of a puff sleeve</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">how to wear a maxi dress for summer</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a red blazer and a black bag</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">how to wear boyfriend jeans like a fashion blogger</p>\n",
              "    </div>\n",
              "    "
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "step:  2/64   lr: 0.01000   loss: 2.1176\n",
            "step:  3/64   lr: 0.01500   loss: 1.7491\n",
            "step:  4/64   lr: 0.02000   loss: 1.5594\n",
            "step:  5/64   lr: 0.02500   loss: 1.6047\n",
            "step:  6/64   lr: 0.03000   loss: 1.3865\n",
            "step:  7/64   lr: 0.02998   loss: 1.4946\n",
            "step:  8/64   lr: 0.02992   loss: 1.6175\n",
            "step:  9/64   lr: 0.02981   loss: 1.3377\n",
            "step: 10/64   lr: 0.02966   loss: 1.4888\n",
            "step: 11/64   lr: 0.02947   loss: 1.3479\n",
            "step: 12/64   lr: 0.02924   loss: 1.3211\n",
            "step: 13/64   lr: 0.02897   loss: 1.0806\n",
            "step: 14/64   lr: 0.02866   loss: 1.1590\n",
            "step: 15/64   lr: 0.02831   loss: 1.1158\n",
            "step: 16/64   lr: 0.02792   loss: 1.1702\n",
            "Model predictions at step 16\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "\n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman wearing a pink blouse with a large puffy sleeve stands on a white wall. the woman&#x27;s hand rests on the wall, and her fingers are intertwined. the wall is white, and the light is shining on the woman&#x27;s hand. the sky is clear, and the sun is shining. the woman is wearing a pink blouse, and her hair is in a bun.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman wearing a white floral dress sits on a stone wall overlooking the ocean. the dress is flowing in the wind and the hat is on her head. the sky is clear and the sun is shining. the woman is wearing a hat and a bag.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a person wearing a red blazer and a black belt bag. the person is standing in the woods and the grass is green. the person is wearing a black belt bag and a black belt. the person is wearing a red blazer and a black belt bag. the person is wearing a red blazer and a black belt bag. the person is wearing a red blazer and a black belt bag. the person is wearing a red blazer and a black belt bag. the person is wearing a red blazer and a black belt bag. the person is wearing a red blazer and a black belt bag. the person is wearing a red blazer and a black belt bag</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman wearing a pink shirt and jeans stands on a stone staircase. she is holding a pink bag and wearing a bracelet. the stairs are made of stone and the woman is wearing a bracelet.</p>\n",
              "    </div>\n",
              "    "
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "step: 17/64   lr: 0.02750   loss: 1.1610\n",
            "step: 18/64   lr: 0.02704   loss: 1.1972\n",
            "step: 19/64   lr: 0.02655   loss: 1.1947\n",
            "step: 20/64   lr: 0.02602   loss: 1.3566\n",
            "step: 21/64   lr: 0.02546   loss: 1.1505\n",
            "step: 22/64   lr: 0.02488   loss: 1.1470\n",
            "step: 23/64   lr: 0.02426   loss: 0.9735\n",
            "step: 24/64   lr: 0.02362   loss: 1.1087\n",
            "step: 25/64   lr: 0.02296   loss: 0.9770\n",
            "step: 26/64   lr: 0.02227   loss: 1.0618\n",
            "step: 27/64   lr: 0.02156   loss: 0.9121\n",
            "step: 28/64   lr: 0.02083   loss: 0.9501\n",
            "step: 29/64   lr: 0.02009   loss: 0.9369\n",
            "step: 30/64   lr: 0.01933   loss: 1.0276\n",
            "step: 31/64   lr: 0.01856   loss: 0.9005\n",
            "step: 32/64   lr: 0.01778   loss: 0.8751\n",
            "Model predictions at step 32\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "\n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a person wearing a pink blouse with a white wall in the background. the blouse has a white collar and a white wall. the person is standing and wearing a white shirt. the wall is white and the person is standing on a white step. the person is wearing a white shirt and a white wall. the person is wearing a pink blouse and a white wall.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman wearing a white floral dress with a brown belt, holding a white bag. the dress has a v-neckline and short sleeves. the woman is standing on a stone wall overlooking the ocean. the sky is clear and blue, with a few white clouds. the water is calm and blue, with a few white boats on the horizon. the woman is wearing a brown belt and a brown bag. the woman&#x27;s hair is in a ponytail.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a person wearing a red blazer with a black fanny pack around their waist. the blazer has a white button down and a black belt. the person is standing in front of a green plant. the plant has a green leaf and a green stem. the person is wearing a black shirt and black pants. the person is wearing a black belt and a black fanny pack. the person is wearing a red blazer and a red shirt. the person is wearing a black jacket and a black pants. the person is wearing a black jacket and a black pants. the person is wearing a black jacket and a black pants. the person is wearing a</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman wearing a pink shirt and blue jeans, holding a pink bag. the woman is standing on a stone staircase, wearing a white cardigan and a gold bracelet. the bag is pink and has a black strap. the woman is wearing a gold bracelet and a gold bracelet on her wrist. the woman is standing on a stone staircase.</p>\n",
              "    </div>\n",
              "    "
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "step: 33/64   lr: 0.01699   loss: 0.8570\n",
            "step: 34/64   lr: 0.01620   loss: 1.0866\n",
            "step: 35/64   lr: 0.01540   loss: 0.7469\n",
            "step: 36/64   lr: 0.01460   loss: 0.8605\n",
            "step: 37/64   lr: 0.01380   loss: 0.6830\n",
            "step: 38/64   lr: 0.01301   loss: 0.7631\n",
            "step: 39/64   lr: 0.01222   loss: 0.7814\n",
            "step: 40/64   lr: 0.01144   loss: 0.8579\n",
            "step: 41/64   lr: 0.01067   loss: 0.7825\n",
            "step: 42/64   lr: 0.00991   loss: 0.6906\n",
            "step: 43/64   lr: 0.00917   loss: 0.7922\n",
            "step: 44/64   lr: 0.00844   loss: 0.7030\n",
            "step: 45/64   lr: 0.00773   loss: 0.7501\n",
            "step: 46/64   lr: 0.00704   loss: 0.6384\n",
            "step: 47/64   lr: 0.00638   loss: 0.6309\n",
            "step: 48/64   lr: 0.00574   loss: 0.6149\n",
            "Model predictions at step 48\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "\n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a person wearing a pink blouse with a puffy sleeve. the blouse has a white wall in the background. the person is standing on a white step, and the sun is shining on the wall. the person is wearing a bracelet on their wrist, and their hand is on the step. the person is wearing a watch on their wrist, and their nails are painted red.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman wears a white floral dress with a brown belt. the dress has a v-neckline and short sleeves. the woman is standing on a stone wall overlooking the ocean. the sky is clear and blue, and the boats are visible on the horizon. the woman is holding a white wicker bag. the woman&#x27;s hair is tied back.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a person wearing a red blazer with a black fanny pack around their waist. the blazer has a white button down and a black belt. the person is standing and wearing a black shirt. the fanny pack is black and has a white loading on it. the person is wearing black pants and black shoes. the person is standing and the grass is green.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman wears a pink shirt and blue jeans. she has a pink bag on her shoulder and a bracelet on her wrist. the stairs are made of stone and the woman is standing on the stairs.</p>\n",
              "    </div>\n",
              "    "
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "step: 49/64   lr: 0.00512   loss: 0.7896\n",
            "step: 50/64   lr: 0.00454   loss: 0.6380\n",
            "step: 51/64   lr: 0.00398   loss: 0.6263\n",
            "step: 52/64   lr: 0.00345   loss: 0.6160\n",
            "step: 53/64   lr: 0.00296   loss: 0.6626\n",
            "step: 54/64   lr: 0.00250   loss: 0.5598\n",
            "step: 55/64   lr: 0.00208   loss: 0.5567\n",
            "step: 56/64   lr: 0.00169   loss: 0.7069\n",
            "step: 57/64   lr: 0.00134   loss: 0.5293\n",
            "step: 58/64   lr: 0.00103   loss: 0.5725\n",
            "step: 59/64   lr: 0.00076   loss: 0.5477\n",
            "step: 60/64   lr: 0.00053   loss: 0.6153\n",
            "step: 61/64   lr: 0.00034   loss: 0.5260\n",
            "step: 62/64   lr: 0.00019   loss: 0.5673\n",
            "step: 63/64   lr: 0.00008   loss: 0.5721\n",
            "step: 64/64   lr: 0.00002   loss: 0.6681\n",
            "Model predictions at step 64\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "\n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a person wearing a pink blouse with a puffy sleeve. the blouse has a white wall in the background. the person is standing and their hand is on the wall. the wall has a white line on it. the person is wearing a pink blouse with a puffy sleeve.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman wearing a white floral dress stands on a pier overlooking the ocean. she is wearing a brown bag and a brown hat. the dress has a v-neckline and a tie belt. the woman has her hand in her pocket and her leg is visible. the sky is clear and blue.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a person wearing a red blazer with a black fanny pack around their waist. the blazer has a white button down and a black belt. the person is standing and wearing a black shirt. the fanny pack is black and has a white loading on it. the person is wearing black pants and black shoes. the tree behind them is green.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman wears a pink shirt and blue jeans. she has a pink bag on her shoulder and a bracelet on her wrist. the stairs are made of gray stone. the woman is standing on the stairs.</p>\n",
              "    </div>\n",
              "    "
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "CPU times: user 12min 34s, sys: 6.31 s, total: 12min 40s\n",
            "Wall time: 13min 5s\n"
          ]
        }
      ],
      "source": [
        "# @title Run training loop.\n",
        "#\n",
        "# Run a short training loop with cosine learning rate schedule.\n",
        "#\n",
        "# Note: the first step can be quite slow on some machines (up to several minutes)\n",
        "# due to XLA compilation of the jax.jit'd function.\n",
        "#\n",
        "%%time\n",
        "\n",
        "BATCH_SIZE = 8\n",
        "TRAIN_EXAMPLES = 512\n",
        "LEARNING_RATE = 0.03\n",
        "\n",
        "TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE\n",
        "EVAL_STEPS = TRAIN_STEPS // 4\n",
        "\n",
        "train_data_it = train_data_iterator()\n",
        "\n",
        "sched_fn = big_vision.utils.create_learning_rate_schedule(\n",
        "    total_steps=TRAIN_STEPS+1, base=LEARNING_RATE,\n",
        "    decay_type=\"cosine\", warmup_percent=0.10)\n",
        "\n",
        "for step in range(1, TRAIN_STEPS+1):\n",
        "  # Make list of N training examples.\n",
        "  examples = [next(train_data_it) for _ in range(BATCH_SIZE)]\n",
        "\n",
        "  # Convert list of examples into a dict of np.arrays and load onto devices.\n",
        "  batch = jax.tree.map(lambda *x: np.stack(x), *examples)\n",
        "  batch = big_vision.utils.reshard(batch, data_sharding)\n",
        "\n",
        "  # Training step and report training loss\n",
        "  learning_rate = sched_fn(step)\n",
        "  params, loss = update_fn(params, batch, learning_rate)\n",
        "\n",
        "  loss = jax.device_get(loss)\n",
        "  print(f\"step: {step:2d}/{TRAIN_STEPS:2d}   lr: {learning_rate:.5f}   loss: {loss:.4f}\")\n",
        "\n",
        "  if step == 1 or (step % EVAL_STEPS) == 0:\n",
        "    print(f\"Model predictions at step {step}\")\n",
        "    html_out = \"\"\n",
        "    for image, caption in make_predictions(\n",
        "        validation_data_iterator(), num_examples=4, batch_size=4):\n",
        "      html_out += render_example(image, caption)\n",
        "    display(HTML(html_out))\n"
      ]
    },
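    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The learning rates printed by the training loop are consistent with a linear warmup followed by cosine decay. As a sketch, assuming that this is how `create_learning_rate_schedule` combines the two (the symbols are introduced here for illustration): with total steps $T = 65$ (`TRAIN_STEPS+1`), warmup steps $W = \\lfloor 0.10 \\cdot 65 \\rfloor = 6$ and base rate $\\eta_0 = 0.03$ (`LEARNING_RATE`), the rate at step $s$ would be\n",
        "\n",
        "$$\\eta(s) = \\eta_0 \\cdot \\min\\left(1, \\frac{s}{W}\\right) \\cdot \\frac{1}{2}\\left(1 + \\cos\\left(\\pi \\cdot \\mathrm{clip}\\left(\\frac{s - W}{T - W}, 0, 1\\right)\\right)\\right)$$\n",
        "\n",
        "As a check against the log: $\\eta(1) = 0.03 / 6 = 0.005$, $\\eta(6) = 0.03$, and $\\eta(7) = 0.03 \\cdot \\tfrac{1}{2}\\left(1 + \\cos(\\pi / 59)\\right) \\approx 0.02998$."
      ]
    },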
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 699
        },
        "id": "hgUhEKjzPdMQ",
        "outputId": "63037cd6-151c-4802-9de8-be2cb7818d12"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Model predictions\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "\n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a person wearing a pink blouse with a puffy sleeve. the blouse has a white wall in the background. the person is standing and their hand is on the wall. the wall has a white line on it. the person is wearing a pink blouse with a puffy sleeve.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman wearing a white floral dress stands on a pier overlooking the ocean. she is wearing a brown bag and a brown hat. the dress has a v-neckline and a tie belt. the woman has her hand in her pocket and her leg is visible. the sky is clear and blue.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a person wearing a red blazer with a black fanny pack around their waist. the blazer has a white button down and a black belt. the person is standing and wearing a black shirt. the fanny pack is black and has a white loading on it. the person is wearing black pants and black shoes. the tree behind them is green.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman wears a pink shirt and blue jeans. she has a pink bag on her shoulder and a bracelet on her wrist. the stairs are made of gray stone. the woman is standing on the stairs.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a pink sweatshirt with a red slogan lies on a bed next to a pair of jeans and a pair of white sneakers. the sweatshirt features long sleeves and a crew neck. the text on the sweatshirt reads &quot;love well, save us.&quot; the sneakers are white and have white laces. the hand on the sweatshirt is gentle.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a man with long blonde hair covers his face with his hand. he wears a navy sweater and a black and white checkered shirt. the sweater has long sleeves and a collar. the man has a beard and mustache. the hair on his face is long and wavy. the man is standing and his hands are on his head. the background is pink.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a white metal rack with a white metal pole and a white metal pole. the rack has a white metal pole and a white metal pole. the rack has a white metal pole and a white metal pole.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a white hoodie hangs on a black coat rack, with a white drawstring on the left side of the hoodie. the coat rack is black, and the wall is white. there is a black circle on the wall, and a black circle on the wall. the hoodie has a white drawstring on the left side of the hoodie.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a woman wears a pair of blue jeans with a black bag. the jeans have a gray wall behind them. the woman is wearing a black bag with a gold chain. the bag has a silver chain and a silver lock. the woman is wearing black boots with black socks. the bag has a silver chain and a silver lock. the woman is wearing a black bracelet on her wrist. the woman is standing and the bag is on her shoulder.</p>\n",
              "    </div>\n",
              "    \n",
              "    <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
              "        <img style=\"width:128px; height:128px;\" src=\"\" />\n",
              "        <p style=\"width:256px; margin:10px; font-size:small;\">a man stands on a sidewalk wearing a blue denim jacket with a white t-shirt, brown pants, and white shoes. he has his hands in his pockets and his legs are stretched out. the jacket has a blue collar and a white t-shirt. the pants have a brown stripe on the side and a white shoe. the man is wearing a white t-shirt and brown pants.</p>\n",
              "    </div>\n",
              "    "
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "CPU times: user 26.2 s, sys: 112 ms, total: 26.3 s\n",
            "Wall time: 32.9 s\n"
          ]
        }
      ],
      "source": [
        "# @title Evaluate the model on all examples.\n",
        "#\n",
        "# The validation data consists of 10 images in a different domain than training\n",
        "# data.\n",
        "%%time\n",
        "\n",
        "print(\"Model predictions\")\n",
        "html_out = \"\"\n",
        "for image, caption in make_predictions(validation_data_iterator(), batch_size=4):\n",
        "  html_out += render_example(image, caption)\n",
        "display(HTML(html_out))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ai0NMbAwsr0j"
      },
      "source": [
        "# Save the final checkpoint"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5H_3CV33_JkV"
      },
      "outputs": [],
      "source": [
        "def npsave(pytree, path):\n",
        "  names_and_vals, _ = big_vision.utils.tree_flatten_with_names(pytree)\n",
        "  with open(path, \"wb\") as f:\n",
        "    np.savez(f, **{k:v for k, v in names_and_vals})\n",
        "\n",
        "# Takes around 4 minutes\n",
        "npsave(params, 'my-custom-paligemma-ckpt.npz')"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}